package standalone

import (
	"github.com/DiracLee/dires-go/app/cmdline"
	"github.com/DiracLee/dires-go/app/connection"
	"github.com/DiracLee/dires-go/app/database"
	"github.com/DiracLee/dires-go/app/handler"
	"github.com/DiracLee/dires-go/app/payload"
	"github.com/DiracLee/dires-go/ds/set"
	"strings"
)

// handleMulti implements MULTI. The command accepts no arguments; when none
// are given it delegates to StartMulti to open a transaction on conn.
func handleMulti(svr *Standalone, conn connection.Connection, args cmdline.CmdLine) payload.Payload {
	if len(args) > 0 {
		return payload.NewArgNumErrPayload(cmdline.CmdMulti)
	}
	return StartMulti(conn)
}

// handleWatch implements WATCH key [key ...]. For every key it stores the
// key's current version (from the connection's selected DB) in the
// connection's watching map. execMulti later compares those versions to
// decide whether the transaction must be aborted.
func handleWatch(svr *Standalone, conn connection.Connection, args cmdline.CmdLine) payload.Payload {
	if len(args) < 1 {
		return payload.NewArgNumErrPayload(cmdline.CmdWatch)
	}
	db, found := selectDB(svr, conn)
	if !found {
		return payload.NewErrPayload("ERR Storage Index is out of range")
	}
	versions := conn.GetWatching()
	for i := range args {
		name := string(args[i])
		versions[name] = db.GetVersion(name)
	}
	return payload.NewOkPayload()
}

// handleExec implements EXEC. It accepts no arguments and runs the
// transaction queued on conn against the connection's selected DB.
//
// args excludes the command name, as in handleMulti and handleDiscard. A
// plain "EXEC" therefore arrives with len(args) == 0, and anything else is an
// arity error.
func handleExec(svr *Standalone, conn connection.Connection, args cmdline.CmdLine) payload.Payload {
	if len(args) != 0 {
		return payload.NewArgNumErrPayload(cmdline.CmdExec)
	}
	db, ok := selectDB(svr, conn)
	if !ok {
		return payload.NewErrPayload("ERR Storage Index is out of range")
	}
	return execMulti(db, conn)
}

// handleDiscard implements DISCARD. It accepts no arguments and hands off to
// DiscardMulti to abandon the commands queued since MULTI.
func handleDiscard(svr *Standalone, conn connection.Connection, args cmdline.CmdLine) payload.Payload {
	if n := len(args); n != 0 {
		return payload.NewArgNumErrPayload(cmdline.CmdDiscard)
	}
	return DiscardMulti(conn)
}

// StartMulti puts conn into MULTI state so that subsequent commands can be
// queued (see EnqueueCmd). A nested MULTI is rejected with an error and
// leaves the connection's state untouched.
func StartMulti(conn connection.Connection) payload.Payload {
	alreadyInMulti := conn.InMultiState()
	if !alreadyInMulti {
		conn.SetMultiState(true)
		return payload.NewOkPayload()
	}
	return payload.NewErrPayload("ERR MULTI calls can not be nested")
}

// DiscardMulti abandons an open transaction. It clears every queued command
// and takes conn out of MULTI state. Calling it without an open MULTI returns
// an error.
func DiscardMulti(conn connection.Connection) payload.Payload {
	if inMulti := conn.InMultiState(); inMulti {
		conn.ClearQueuedCmds()
		conn.SetMultiState(false)
		return payload.NewOkPayload()
	}
	return payload.NewErrPayload("ERR DISCARD without MULTI")
}

// isWatchingChanged reports whether any key in watching now has a version in
// db that differs from the version recorded for it. It stops at the first
// mismatch.
func isWatchingChanged(db database.Storage, watching map[string]int64) bool {
	changed := false
	for key, recorded := range watching {
		if db.GetVersion(key) != recorded {
			changed = true
			break
		}
	}
	return changed
}

// execMulti executes the commands queued on conn since MULTI as one
// transaction against db.
//
// Flow:
//  1. Ask each queued command's Prepare function for the keys it writes and
//     reads, and add every watched key to the read set.
//  2. Lock all of those keys with db.RWLock for the rest of the call.
//  3. If any watched key's version changed, return an empty multi-bulk reply
//     without running anything.
//  4. Run the commands in order, capturing undo logs before each one. On the
//     first error reply, replay the undo logs of the commands that already
//     succeeded (newest first) and return an EXECABORT error.
//  5. On success, bump the versions of all written keys and return the
//     per-command results.
//
// conn is taken out of MULTI state when this function returns, whatever the
// outcome.
func execMulti(db database.DB, conn connection.Connection) payload.Payload {
	if !conn.InMultiState() {
		return payload.NewErrPayload("ERR EXEC without MULTI")
	}
	defer conn.SetMultiState(false)
	cmdLines := conn.GetQueuedCmdLine()
	watching := conn.GetWatching()

	// prepare: collect every key the transaction touches so all can be locked up front
	writeKeys := make([]string, 0) // may contain duplicates
	readKeys := make([]string, 0, len(watching))
	for _, cmdLine := range cmdLines {
		cmdName := strings.ToLower(string(cmdLine[0]))
		cmd, ok := handler.CmdTable[cmdName]
		// EnqueueCmd normally filters these out, but a command queued by other
		// means would otherwise panic here (missing table entry / nil Prepare).
		if !ok {
			return payload.NewErrPayload("ERR unknown command '" + cmdName + "'")
		}
		if cmd.Prepare == nil {
			return payload.NewErrPayload("ERR command '" + cmdName + "' cannot be used in MULTI")
		}
		write, read := cmd.Prepare(cmdLine[1:])
		writeKeys = append(writeKeys, write...)
		readKeys = append(readKeys, read...)
	}
	// Watched keys are read-locked so they cannot change between the version
	// check below and the end of execution.
	for key := range watching {
		readKeys = append(readKeys, key)
	}
	db.RWLock(writeKeys, readKeys)
	defer db.RWUnlock(writeKeys, readKeys)

	if isWatchingChanged(db, watching) { // watching keys changed, abort
		return payload.NewEmptyMultiBulkPayload()
	}

	// execute; undo logs are kept only for commands that succeeded, because the
	// failed command itself is not rolled back
	results := make([]payload.Payload, 0, len(cmdLines))
	undoCmdLines := make([][]cmdline.CmdLine, 0, len(cmdLines))
	failed := false
	for _, cmdLine := range cmdLines {
		undoLogs := GetUndoLogs(db, cmdLine) // must be captured before the command mutates db
		result := execWithLock(db, cmdLine)
		if payload.IsErrorPayload(result) {
			failed = true
			break
		}
		undoCmdLines = append(undoCmdLines, undoLogs)
		results = append(results, result)
	}
	if !failed {
		db.AddVersion(writeKeys...)
		return payload.NewMultiPayload(results)
	}

	// roll back the successful commands, most recent first
	for i := len(undoCmdLines) - 1; i >= 0; i-- {
		for _, undoLine := range undoCmdLines[i] {
			execWithLock(db, undoLine)
		}
	}
	return payload.NewErrPayload("EXECABORT Transaction discarded because of previous errors.")
}

// execWithLock runs a single command line against db and returns its reply.
// It acquires no locks itself. execMulti calls it while holding db.RWLock on
// the transaction's keys. Unknown commands and wrong arity are reported as
// error payloads instead of being executed.
func execWithLock(db database.DB, cmdLine [][]byte) payload.Payload {
	name := strings.ToLower(string(cmdLine[0]))
	cmd, found := handler.CmdTable[name]
	switch {
	case !found:
		return payload.NewErrPayload("ERR unknown command '" + name + "'")
	case !validateArity(cmd.Arity, cmdLine):
		return payload.NewArgNumErrPayload(name)
	default:
		return cmd.Executor(db, cmdLine[1:])
	}
}

// forbiddenInMulti holds the command names that EnqueueCmd refuses to queue
// inside a MULTI transaction.
var forbiddenInMulti = set.NewSample(
	cmdline.CmdFlushDB,
	cmdline.CmdFlushAll,
)

// EnqueueCmd validates cmdLine and, if it is acceptable, appends it to conn's
// MULTI pending queue and replies QUEUED. It rejects:
//   - unknown commands,
//   - commands listed in forbiddenInMulti,
//   - commands without a Prepare function (execMulti relies on Prepare to
//     learn which keys to lock),
//   - command lines with the wrong number of arguments.
func EnqueueCmd(conn connection.Connection, cmdLine [][]byte) payload.Payload {
	name := strings.ToLower(string(cmdLine[0]))
	cmd, known := handler.CmdTable[name]
	if !known {
		return payload.NewErrPayload("ERR unknown command '" + name + "'")
	}
	if forbiddenInMulti.Has(name) || cmd.Prepare == nil {
		return payload.NewErrPayload("ERR command '" + name + "' cannot be used in MULTI")
	}
	if !validateArity(cmd.Arity, cmdLine) {
		// difference with redis: we won't enqueue command line with wrong arity
		return payload.NewArgNumErrPayload(name)
	}
	conn.EnqueueCmd(cmdLine)
	return payload.NewQueuedPayload()
}

// GetUndoLogs returns the undo command lines produced by the Undo function
// registered for cmdLine's command. execMulti replays these to roll back a
// failed transaction. It returns nil when the command is unknown or has no
// Undo function.
func GetUndoLogs(db database.DB, cmdLine cmdline.CmdLine) []cmdline.CmdLine {
	cmd, exists := handler.CmdTable[strings.ToLower(string(cmdLine[0]))]
	if !exists || cmd.Undo == nil {
		return nil
	}
	return cmd.Undo(db, cmdLine[1:])
}
